package main

import (
	"archive/tar"
	"context"
	"encoding/gob"
	"encoding/json"
	"fmt"
	. "git.oschina.net/wkc/tar-fs/schema"
	"github.com/armon/go-radix"
	"github.com/golang/snappy"
	"github.com/namsral/flag"
	"github.com/urfave/negroni"
	"gopkg.in/cheggaaa/pb.v1"
	// "github.com/k0kubun/pp"
	"crypto/rand"
	"github.com/pilu/xrequestid"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/zbindenren/negroni-prometheus"
	"golang.org/x/crypto/salsa20"
	"io"
	"io/ioutil"
	"net/http"
	_ "net/http/pprof"
	"os"
	"qiniupkg.com/x/log.v7"
	"strconv"
	"strings"
	// "sync"
	"crypto/sha256"
	"github.com/dustin/go-humanize"
	// "sync/atomic"
	"time"
)

// checkBlockSnappy measures the snappy compression ratio of
// "node_modules.tar" compressed in fixed 100 KiB chunks, printing a
// per-chunk ratio and an overall ratio at the end.
func checkBlockSnappy() {
	f, err := os.Open("node_modules.tar")
	if err != nil {
		log.Panicln(err)
		return
	}
	// BUG FIX: the file handle was never closed.
	defer f.Close()
	src := make([]byte, 100*1024)
	dst := make([]byte, 100*1024)
	alls := 0 // total uncompressed bytes
	alld := 0 // total compressed bytes
	for {
		n, err := f.Read(src)
		if err == io.EOF {
			break
		}
		// BUG FIX: the original called log.Panicln(err) unconditionally
		// here, panicking (with a nil err) after every successful read.
		if err != nil {
			log.Panicln(err)
		}
		s := src[:n]
		d := snappy.Encode(dst, s)
		fmt.Printf("%d -> %d %d%%\n", len(s), len(d), len(d)*100/len(s))
		alls += len(s)
		alld += len(d)
	}
	// BUG FIX: guard against division by zero on an empty file.
	if alls > 0 {
		fmt.Printf("%d -> %d %d%%\n", alls, alld, alld*100/alls)
	}
}

// checkEveryFileSnappy measures the snappy compression ratio of
// "node_modules.tar" when each tar entry's payload is compressed in
// chunks of up to 100 KiB, printing per-chunk and overall ratios.
func checkEveryFileSnappy() {
	f, err := os.Open("node_modules.tar")
	if err != nil {
		log.Println(err)
		return
	}
	// BUG FIX: the file handle was never closed.
	defer f.Close()
	t := tar.NewReader(f)
	src := make([]byte, 100*1024)
	dst := make([]byte, 100*1024)
	alls := 0 // total uncompressed bytes
	alld := 0 // total compressed bytes
	for {
		h, err := t.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			log.Panicln(h, err)
			return
		}
		if h.Size == 0 {
			// directories and empty files carry no payload
			continue
		}
		for {
			n, err := t.Read(src)
			if err == io.EOF {
				break
			}
			if err != nil {
				log.Panicln(h, err)
				return
			}
			s := src[:n]
			d := snappy.Encode(dst, s)
			fmt.Printf("%d -> %d %d%%\n", len(s), len(d), len(d)*100/len(s))
			alls += len(s)
			alld += len(d)
		}
	}
	// BUG FIX: guard against division by zero on an empty archive.
	if alls > 0 {
		fmt.Printf("%d -> %d %d%%\n", alls, alld, alld*100/alls)
	}
}

// BufSize is the compression block size (1 MiB): tar payload is read,
// snappy-compressed, and encrypted in chunks of at most this many bytes.
// It is also recorded per file as MultipleBlock.BlockSize in the index.
const BufSize = 1024 * 1024

// makeSnappyFile converts the tar archive srcFilename into the ".sf"
// data format plus a ".sfi" index file.
//
// Each tar entry's payload is read in BufSize chunks; every chunk is
// snappy-compressed, encrypted with a fresh per-file salt, and appended
// to the current data file. When a data file would exceed maxSplitSize
// the output rolls over to "<base>.<n>.sf". Finally the per-file
// metadata (tar header, block offsets/sizes, sha256 of the plaintext,
// salt and cipher id) is gob-encoded, snappy-compressed, and written to
// dstIndexFilename.
func makeSnappyFile(srcFilename string, dstFilename string, dstIndexFilename string, maxSplitSize uint64, encrypt Encrypter) (err error) {
	srcFile, err := os.Open(srcFilename)
	if err != nil {
		return errors.Wrap(err, "open srcFilename")
	}
	defer srcFile.Close()

	// Progress bar sized to the source file, fed through a proxy reader below.
	var bar *pb.ProgressBar
	{
		info, err := srcFile.Stat()
		if err != nil {
			return errors.Wrap(err, "srcFile.Stat")
		}
		bar = pb.New64(info.Size()).SetUnits(pb.U_BYTES)
		bar.ShowTimeLeft = true
		bar.ShowSpeed = true
		bar.Start()
	}

	log.Info("writing", dstFilename)
	dstFile, err := os.OpenFile(dstFilename, os.O_CREATE|os.O_WRONLY, 0666)
	if err != nil {
		return errors.Wrap(err, "open dstFilename")
	}
	// Deferred as a closure (not a bound method value) so that the split
	// rollover below can reassign dstFile and this still closes the
	// most recently opened file.
	defer func() {
		dstFile.Close()
	}()

	log.Info("writing", dstIndexFilename)
	dstIndexFile, err := os.OpenFile(dstIndexFilename, os.O_CREATE|os.O_WRONLY, 0666)
	if err != nil {
		return errors.Wrap(err, "open dstIndexFilename")
	}
	defer dstIndexFile.Close()

	allFiles := make([]File, 0)

	srcFileTarStream := tar.NewReader(bar.NewProxyReader(srcFile))

	// dstFileSplitIndex: current ".N.sf" suffix (0 = the unsuffixed file);
	// dstSeek: write offset within the current split file;
	// dstAllOffset: total compressed bytes across all splits (stats only).
	var dstFileSplitIndex, dstSeek, dstAllOffset int64

	srcBuf := make([]byte, BufSize)
	dstBuf := make([]byte, BufSize)

	for {
		h, err := srcFileTarStream.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return errors.Wrapf(err, "read tar header error")
		}
		if h.Size == 0 {
			// directories and empty files carry no payload: skip
			continue
		}
		hash := sha256.New()
		readTime := 0 // number of chunks read for this entry
		f := File{
			TarHeader: h,
			MultipleBlock: &MultipleBlock{
				BlockSize: BufSize,
			},
			FileIndex: dstFileSplitIndex,
		}
		// Roll over to a new split file when this entry (plus a little
		// headroom) would push the current one past maxSplitSize.
		// NOTE(review): h.Size is the uncompressed size, so this is a
		// conservative estimate of compressed growth.
		if uint64(dstSeek+h.Size+1024) > maxSplitSize {
			dstFileSplitIndex++
			dstSeek = 0
			dstFile.Close()
			dstFilenameBase := strings.TrimSuffix(dstFilename, ".sf")
			name := fmt.Sprintf("%s.%d.sf", dstFilenameBase, dstFileSplitIndex)
			log.Info("writeing", name) // NOTE(review): "writeing" typo kept as-is
			dstFile, err = os.OpenFile(name, os.O_CREATE|os.O_WRONLY, 0666)
			if err != nil {
				return errors.Wrap(err, "open dstFilename")
			}
		}
		// One salt per tar entry; every block of the entry shares it.
		salt := encrypt.salt()
		for {
			n, err := srcFileTarStream.Read(srcBuf)
			if err == io.EOF {
				break
			}
			if err != nil {
				return errors.Wrapf(err, "read tar data error")
			}
			src := srcBuf[:n]
			hash.Write(src) // sha256 over the plaintext, not the ciphertext
			d := snappy.Encode(dstBuf, src)
			encrypted := make([]byte, len(d))
			encrypt.encrypt(&encrypted, d, salt)
			n, err = dstFile.Write(encrypted)
			if err != nil {
				return errors.Wrapf(err, "write data")
			}
			if n != len(d) {
				log.Println("bad write")
			}
			// Record where this compressed block lives in the split file.
			f.MultipleBlock.CompressOffsets = append(f.MultipleBlock.CompressOffsets, dstSeek)
			dstSeek += int64(n)
			dstAllOffset += int64(n)
			f.MultipleBlock.CompressSizes = append(f.MultipleBlock.CompressSizes, int64(n))
			readTime++
		}
		// Compress == 1: single block, offset/size stored inline;
		// Compress == 2: multiple blocks, details in MultipleBlock.
		if readTime == 1 {
			f.Compress = 1
			f.CompressOffset = f.MultipleBlock.CompressOffsets[0]
			f.CompressSize = f.MultipleBlock.CompressSizes[0]
			f.MultipleBlock = nil
		} else if readTime > 1 {
			f.Compress = 2
			f.CompressOffset = -1
			f.CompressSize = -1
		} else {
			// h.Size > 0 but no data could be read: corrupt archive
			log.Panicln("bad file", h.Name)
		}
		f.Encrypt = &Encrypt{Salt: salt, Cipher: encrypt.cipher()}
		f.DataSha256 = hash.Sum(nil)
		allFiles = append(allFiles, f)
	}
	{

		// Print the overall compression ratio and entry count.
		srcInfo, err := srcFile.Stat()
		if err != nil {
			return errors.Wrapf(err, "srcFile.Stat")
		}
		fmt.Printf("compress %.2f\n", float64(dstAllOffset*100)/float64((srcInfo.Size())))
		fmt.Printf("file num %d\n", len(allFiles))
	}
	// The index is a gob stream inside a snappy stream.
	dstIndexFileSnappy := snappy.NewWriter(dstIndexFile)
	err = gob.NewEncoder(dstIndexFileSnappy).Encode(IndexMetadataFile{
		Files:   allFiles,
		Version: "1",
	})
	if err != nil {
		return errors.Wrap(err, "write index file")
	}
	dstIndexFileSnappy.Close()
	return nil
}

// Tree is the in-memory index of the archive: a radix tree keyed by tar
// entry name, mapping to File records. maxFileIndex is the highest
// File.FileIndex seen while loading, i.e. how many ".N.sf" split data
// files exist beyond the first.
type Tree struct {
	tree         *radix.Tree
	maxFileIndex int64
}

// Get looks up the exact entry name s in the index. On a hit it returns
// a pointer to a private copy of the stored File, so callers cannot
// mutate the tree's record.
func (t *Tree) Get(s string) (f *File, ok bool) {
	value, hit := t.tree.Get(s)
	if !hit {
		return nil, false
	}
	file := value.(File)
	return &file, true
}

// Prefix walks every index entry whose name starts with s, streaming the
// File records over the returned channel. The channel is closed when the
// walk finishes or ctx is cancelled.
func (t *Tree) Prefix(ctx context.Context, s string) (fs chan File) {
	fs = make(chan File)
	go func() {
		defer close(fs)
		t.tree.WalkPrefix(s, func(key string, v interface{}) bool {
			// BUG FIX: the original only checked ctx *before* a plain
			// channel send, so a consumer that cancelled the context and
			// stopped receiving left this goroutine blocked on the send
			// forever. Selecting on the send itself lets cancellation
			// always interrupt it.
			select {
			case fs <- v.(File):
				return false
			case <-ctx.Done():
				return true // returning true stops the walk
			}
		})
	}()
	return
}

// NewTree reads the snappy-compressed, gob-encoded index file and builds
// a radix tree over its entries, also tracking the highest split-file
// index so callers know how many ".N.sf" data files to open.
func NewTree(dstIndexFilename string) (*Tree, error) {
	dstIndexFile, err := os.Open(dstIndexFilename)
	if err != nil {
		return nil, errors.Wrap(err, "open dstFilename")
	}
	defer dstIndexFile.Close()
	m := IndexMetadataFile{}
	err = gob.NewDecoder(snappy.NewReader(dstIndexFile)).Decode(&m)
	if err != nil {
		// BUG FIX: the error context said "write index file" although
		// this path reads and decodes the index.
		return nil, errors.Wrap(err, "decode index file")
	}
	tree := radix.New()
	t := Tree{tree: tree}
	for _, file := range m.Files {
		tree.Insert(file.TarHeader.Name, file)
		if file.FileIndex > t.maxFileIndex {
			t.maxFileIndex = file.FileIndex
		}
	}
	return &t, nil
}

// Server serves files out of a .sf/.sfi archive over HTTP. Compressed
// blocks are read either from the locally opened dataFiles (indexed by
// File.FileIndex; an entry may be nil when -allowedMissing tolerated a
// missing split file) or, when fromHTTPURL is non-empty, via HTTP range
// requests against a remote copy of the data files.
type Server struct {
	tree        *Tree
	dataFiles   []*os.File
	fromHTTPURL string
	encrypt     Encrypter
}

// close releases every data file that was successfully opened; nil
// entries (missing split files) are skipped.
func (s *Server) close() {
	for _, file := range s.dataFiles {
		if file == nil {
			continue
		}
		file.Close()
	}
}

// ServeHTTP answers two kinds of requests:
//
//   - ?prefix=P (optionally &verbose=true): streams the names (or full
//     JSON records) of every index entry under prefix P, one per line;
//   - /<tar entry name>: decompresses and returns that file's content.
//
// The non-standard status codes mirror the existing protocol: 510 read
// error, 511 decode error, 512 bad metadata, 513 data-file index
// overflow, 520 data file missing.
func (s *Server) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
	q := req.URL.Query()
	if q.Get("prefix") != "" {
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()
		fs := s.tree.Prefix(ctx, q.Get("prefix"))
		verbose := q.Get("verbose") == "true"
		for f := range fs {
			if verbose {
				var buf []byte
				buf, _ = json.Marshal(f)
				_, err := resp.Write(buf)
				if err != nil {
					// client went away: cancel so the tree walk stops
					cancel()
				}
			} else {
				_, err := resp.Write([]byte(f.TarHeader.Name))
				if err != nil {
					cancel()
				}
			}
			_, err := resp.Write([]byte{'\n'})
			if err != nil {
				cancel()
			}
		}
		return
	}
	fileName := req.URL.Path[1:]
	// log.Debug("req ", fileName)
	h, ok := s.tree.Get(fileName)
	if !ok || h.TarHeader == nil {
		http.Error(resp, "404 file not found", 404)
		return
	}
	// readAt abstracts the source of compressed bytes: remote HTTP range
	// requests when fromHTTPURL is set, otherwise a local split file.
	var readAt func(buf []byte, offset int64) (int, error)
	if s.fromHTTPURL != "" {
		readAt = func(buf []byte, offset int64) (int, error) {
			url := ""
			if h.FileIndex == 0 {
				url = s.fromHTTPURL + ".sf"
			} else {
				url = fmt.Sprintf("%s.%d.sf", s.fromHTTPURL, h.FileIndex)
			}
			rangeReq, err := http.NewRequest("GET", url, nil)
			if err != nil {
				return 0, err
			}
			rangeReq.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", offset, int64(len(buf))+offset-1))
			rangeResp, err := http.DefaultClient.Do(rangeReq)
			if err != nil {
				return 0, errors.Wrap(err, "http req do")
			}
			// BUG FIX: the response body was never closed, leaking a
			// connection on every block read.
			defer rangeResp.Body.Close()
			body, err := ioutil.ReadAll(rangeResp.Body)
			if err != nil {
				return 0, errors.Wrap(err, "read http body")
			}
			if len(body) != len(buf) {
				return 0, errors.New("readed http body length error")
			}
			copy(buf, body)
			return len(buf), nil
		}
	} else {
		if int(h.FileIndex) >= len(s.dataFiles) {
			log.Error("datafile index overflow")
			resp.WriteHeader(513)
			return
		}
		fd := s.dataFiles[h.FileIndex]
		if fd == nil {
			// split file was absent at startup (-allowedMissing)
			log.Error("datafile not found, index:", h.FileIndex, fileName)
			resp.WriteHeader(520)
			return
		}
		readAt = func(buf []byte, offset int64) (int, error) {
			return fd.ReadAt(buf, offset)
		}
	}

	if h.Compress == 1 {
		// Single-block file: one read, one decrypt, one decode.
		src := make([]byte, h.CompressSize)
		dst := make([]byte, h.TarHeader.Size)
		resp.Header().Add("content-length", strconv.FormatInt(h.TarHeader.Size, 10))

		n, err := readAt(src, h.CompressOffset)
		// BUG FIX: validate the read before decrypting; the original
		// decrypted a possibly short/invalid buffer first.
		if err != nil || int64(n) != h.CompressSize {
			log.Error("dataFile Read", err, fileName)
			resp.WriteHeader(510)
			return
		}
		encrypted := src[:n]
		data := make([]byte, len(encrypted))
		s.encrypt.decrypt(&data, encrypted, h.Encrypt.Salt)

		data, err = snappy.Decode(dst, data)
		if err != nil || int64(len(data)) != h.TarHeader.Size {
			log.Error("dataFile Decode", err, fileName)
			resp.WriteHeader(511)
			return
		}
		resp.Write(data)
	} else {
		if h.MultipleBlock == nil {
			log.Error("bad file, MultipleBlock nil", fileName)
			resp.WriteHeader(512)
			return
		}
		src := make([]byte, h.MultipleBlock.BlockSize)
		dst := make([]byte, h.MultipleBlock.BlockSize)
		resp.Header().Add("content-length", strconv.FormatInt(h.TarHeader.Size, 10))
		l := len(h.MultipleBlock.CompressOffsets)
		for i := 0; i < l; i++ {
			offset := h.MultipleBlock.CompressOffsets[i]
			size := h.MultipleBlock.CompressSizes[i]
			n, err := readAt(src[:size], offset)
			if err != nil || int64(n) != size {
				log.Error("dataFile Read", err, fileName, n, size)
				if i == 0 {
					// headers not sent yet: we can still signal an error
					resp.WriteHeader(510)
				}
				return
			}
			encrypted := src[:size]
			data := make([]byte, len(encrypted))
			s.encrypt.decrypt(&data, encrypted, h.Encrypt.Salt)
			data, err = snappy.Decode(dst, data)
			// Every block decompresses to BlockSize except possibly the last.
			want := h.MultipleBlock.BlockSize
			if i == l-1 {
				// BUG FIX: when the file size is an exact multiple of
				// BlockSize the original expected a 0-byte final block
				// (Size % BlockSize == 0) and wrongly rejected the file.
				if r := h.TarHeader.Size % h.MultipleBlock.BlockSize; r != 0 {
					want = r
				}
			}
			if err != nil || int64(len(data)) != want {
				log.Error("dataFile Decode", err, fileName)
				if i == 0 {
					resp.WriteHeader(510)
				}
				return
			}
			resp.Write(data)
		}
	}
}

// newServer builds a Server: it loads the index, then (unless blocks are
// to be fetched from a remote HTTP URL) opens every local split data
// file up front. With allowedMissing, a split file that fails to open is
// logged and recorded as a nil entry instead of aborting.
func newServer(dataFilename string, indexFilename string, allowedMissing bool, fromHTTPURL string, encrypt Encrypter) (*Server, error) {
	tree, err := NewTree(indexFilename)
	if err != nil {
		return nil, errors.Wrap(err, "NewTree")
	}
	srv := &Server{
		tree:        tree,
		fromHTTPURL: fromHTTPURL,
		encrypt:     encrypt,
	}
	if fromHTTPURL != "" {
		// remote mode: no local files to open
		return srv, nil
	}
	base := strings.TrimSuffix(dataFilename, ".sf")
	for i := int64(0); i <= tree.maxFileIndex; i++ {
		name := dataFilename
		if i != 0 {
			name = fmt.Sprintf("%s.%d.sf", base, i)
		}
		log.Info("open", name)
		f, err := os.Open(name)
		if err != nil {
			if !allowedMissing {
				return nil, errors.Wrap(err, "open dataFilename")
			}
			log.Warn("miss file", name)
		}
		// a nil entry marks a missing-but-tolerated split file
		srv.dataFiles = append(srv.dataFiles, f)
	}
	return srv, nil
}

// ls prints the name of every entry in the index file to stdout.
func ls(indexFilename string) error {
	tree, err := NewTree(indexFilename)
	if err != nil {
		return errors.Wrap(err, "NewTree")
	}
	// empty prefix matches everything; the channel is drained fully,
	// so no cancellation is needed
	files := tree.Prefix(context.Background(), "")
	for f := range files {
		fmt.Println(f.TarHeader.Name)
	}
	return nil
}

// main wires up the three operating modes selected by mutually exclusive
// flags (namsral/flag also reads them from the environment):
//
//	-ls      print every entry name in <tarFile>.sfi
//	-gen     pack <tarFile> into <tarFile>.sf (plus split files) and
//	         write the index <tarFile>.sfi
//	-server  serve the archive over HTTP on -port
//
// With no mode flag it prints usage.
func main() {
	// {
	// 	log.SetOutputLevel(0)
	// 	t, err := NewTree("node_modules.tar.sfi")
	// 	if err != nil {
	// 		log.Panicf("%+v", err)
	// 	}
	// 	pp.Println(t.Get("node_modules/when/es6-shim/Promise.js"))
	// 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	// 	defer cancel()
	// 	for f := range t.Prefix(ctx, "node_modules/") {
	// 		log.Info(f.TarHeader.Name)
	// 	}
	// 	return
	// }

	// param collects all command-line / environment configuration.
	var param = struct {
		server           bool
		gen              bool
		tarFile          string
		port             string
		maxSplitSize     string // human-readable, e.g. "1GB"
		maxSplitSizeByte uint64 // parsed form of maxSplitSize
		allowedMissing   bool   // tolerate missing split data files
		fromHTTPURL      string // fetch blocks from a remote copy instead of local files
		crypt            string // "" (no encryption) or "salsa20"
		key              string // encryption passphrase
		ls               bool
	}{}
	flag.BoolVar(&param.server, "server", false, "")
	flag.BoolVar(&param.gen, "gen", false, "")
	flag.BoolVar(&param.allowedMissing, "allowedMissing", false, "")
	flag.BoolVar(&param.ls, "ls", false, "")
	flag.StringVar(&param.tarFile, "tarFile", "1.tar", "")
	flag.StringVar(&param.port, "port", ":7766", "")
	flag.StringVar(&param.fromHTTPURL, "fromHTTPURL", "", "")
	flag.StringVar(&param.maxSplitSize, "maxSplitSize", "1GB", "")
	flag.StringVar(&param.crypt, "crypt", "", "salsa20")
	flag.StringVar(&param.key, "key", "", "")
	flag.Parse()

	// Parse the human-readable split size into bytes.
	{
		b, err := humanize.ParseBytes(param.maxSplitSize)
		if err != nil {
			log.Panic(err)
		}
		param.maxSplitSizeByte = b
	}
	// Select the Encrypter implementation; salsa20 requires a key.
	var encrypt Encrypter
	{
		switch param.crypt {
		case "":
			encrypt = NewEncryptNoop(param.key)
		case "salsa20":
			if param.key == "" {
				log.Panic("key should not empty")
			}
			encrypt = NewEncryptSalsa20(param.key)
		default:
			log.Panicln("bad crypt")
		}
	}

	switch {
	case param.ls:
		err := ls(param.tarFile + ".sfi")
		if err != nil {
			log.Panicf("%+v", err)
		}
		return
	case param.gen:
		err := makeSnappyFile(param.tarFile, param.tarFile+".sf", param.tarFile+".sfi", param.maxSplitSizeByte, encrypt)
		if err != nil {
			log.Panicf("%+v", err)
		}
		return
	case param.server:
		// level 0 — presumably the most verbose (debug) level of
		// qiniupkg log.v7; TODO confirm
		log.SetOutputLevel(0)
		start := time.Now()
		mux, err := newServer(param.tarFile+".sf", param.tarFile+".sfi", param.allowedMissing, param.fromHTTPURL, encrypt)
		log.Info("load use ", time.Since(start))
		if err != nil {
			log.Panicf("%+v", err)
		}
		n := negroni.Classic()
		// tag every request with an X-Request-Id and collect Prometheus metrics
		n.Use(xrequestid.New(16))
		n.Use(negroniprometheus.NewMiddleware("tarfs"))
		r := http.NewServeMux()
		r.Handle("/metrics", prometheus.Handler())
		r.Handle("/", mux)
		n.UseHandler(r)
		// a bare port number binds to localhost only
		if !strings.Contains(param.port, ":") {
			param.port = "127.0.0.1:" + param.port
		}
		n.Run(param.port)
		return
	}
	flag.Usage()
}

// Encrypter is the pluggable per-block encryption strategy used by both
// the packer (makeSnappyFile) and the server. Implementations in this
// file: encryptSalsa20 (cipher id 1) and encryptNoop (cipher id 0).
type Encrypter interface {
	// encrypt fills *out (which must have len(in)) with the encrypted form of in.
	encrypt(out *[]byte, in []byte, salt []byte)
	// decrypt is the inverse of encrypt under the same salt.
	decrypt(out *[]byte, in []byte, salt []byte)
	// salt returns a fresh per-file nonce (nil for the noop implementation).
	salt() []byte
	// cipher returns the numeric cipher ID recorded in the index.
	cipher() int
}

// encryptSalsa20 implements Encrypter using the salsa20 stream cipher
// with a fixed 32-byte key and a random 8-byte per-file nonce ("salt").
type encryptSalsa20 struct {
	key [32]byte
}

// NewEncryptSalsa20 derives the 32-byte cipher key from the passphrase:
// its bytes are copied in, zero-padded (or truncated) to 32 bytes.
func NewEncryptSalsa20(key string) (e *encryptSalsa20) {
	e = new(encryptSalsa20)
	copy(e.key[:], key)
	return
}

// encrypt XORs in with the salsa20 keystream into *out, which must be
// exactly len(in) bytes.
func (e *encryptSalsa20) encrypt(out *[]byte, in []byte, salt []byte) {
	if len(*out) != len(in) {
		panic("bad buf")
	}
	salsa20.XORKeyStream(*out, in, salt, &e.key)
}

// decrypt is identical to encrypt: a XOR stream cipher is its own inverse.
func (e *encryptSalsa20) decrypt(out *[]byte, in []byte, salt []byte) {
	if len(*out) != len(in) {
		panic("bad buf")
	}
	salsa20.XORKeyStream(*out, in, salt, &e.key)
}

// salt returns a fresh random 8-byte salsa20 nonce.
func (e *encryptSalsa20) salt() []byte {
	nonce := make([]byte, 8)
	if _, err := rand.Read(nonce); err != nil {
		panic(err)
	}
	return nonce
}

// cipher reports the cipher ID stored in the index (1 = salsa20).
func (e *encryptSalsa20) cipher() int {
	return 1
}

// encryptNoop is the Encrypter used when no -crypt flag is given: data
// passes through unchanged.
type encryptNoop struct {
}

// NewEncryptNoop returns a pass-through Encrypter; the key is ignored.
func NewEncryptNoop(key string) (e *encryptNoop) {
	return new(encryptNoop)
}

// encrypt copies in to out unchanged.
func (e *encryptNoop) encrypt(out *[]byte, in []byte, salt []byte) {
	copy(*out, in)
}

// decrypt copies in to out unchanged.
func (e *encryptNoop) decrypt(out *[]byte, in []byte, salt []byte) {
	copy(*out, in)
}

// salt returns nil: no nonce is needed for pass-through.
func (e *encryptNoop) salt() (salt []byte) {
	return nil
}

// cipher reports the cipher ID stored in the index (0 = none).
func (e *encryptNoop) cipher() int {
	return 0
}

/* TODO:
- support different compression algorithms
- automatically skip data that does not compress
- support HTTP range requests
- support hash verification, HTTP etag
- mmap
- etag
*/
